{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation \n",
    "## Append to path and import packages\n",
    "In case gumpy is not installed as package, you may have to specify the path to the gumpy directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reset\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt \n",
    "import sys, os, os.path\n",
    "sys.path.append('/.../gumpy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import gumpy\n",
    "This may take a while, as gumpy as several dependencies that will be loaded automatically"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import gumpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select workflow\n",
    "Select the actions you want to perform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import data\n",
    "To import data, you have to specify the directory in which your data is stored in. For the example given here, the data is in the subfolder ``../EEG-Data/Graz_data/data``. \n",
    "Then, one of the classes that subclass from ``dataset`` can be used to load the data. In the example, we will use the GrazB dataset, for which ``gumpy`` already includes a corresponding class. If you have different data, simply subclass from ``gumpy.dataset.Dataset``."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First specify the location of the data and some \n",
    "# identifier that is exposed by the dataset (e.g. subject)\n",
    "\n",
    "data_base_dir = '/.../.../Data'\n",
    "\n",
    "grazb_base_dir = os.path.join(data_base_dir, 'Graz')\n",
    "subject = 'B01'\n",
    "\n",
    "# The next line first initializes the data structure. \n",
    "# Note that this does not yet load the data! In custom implementations\n",
    "# of a dataset, this should be used to prepare file transfers, \n",
    "# for instance check if all files are available, etc.\n",
    "grazb_data = gumpy.data.GrazB(grazb_base_dir, subject)\n",
    "\n",
    "# Finally, load the dataset\n",
    "grazb_data.load()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The abstract class allows to print some information about the contained data. This is a commodity function that allows quick inspection of the data as long as all necessary fields are provided in the subclassed variant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grazb_data.print_stats()\n",
    "# labels = grazb_data.labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Postprocess data\n",
    "Usually it is necessary to postprocess the raw data before you can properly use it. ``gumpy`` provides several methods to easily do so, or provides implementations that can be adapted to your needs.\n",
    "\n",
    "Most methods internally use other Python toolkits, for instance ``sklearn``, which is heavily used throughout ``gumpy``. Thereby, it is easy to extend ``gumpy`` with custom filters. In addition, we expect users to have to manipulate the raw data directly as shown in the following example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Common average re-referencing the data  to Cz\n",
    "Some data is required to be re-referenced to a certain electrode. Because this may depend on your dataset, there is no common function provided by ``gumpy`` to do so. However and if sub-classed according to the documentation, you can access the raw-data directly as in the following example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    grazb_data.raw_data[:, 0] -= 2 * grazb_data.raw_data[:, 1]\n",
    "    grazb_data.raw_data[:, 2] -= 2 * grazb_data.raw_data[:, 2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notch and Band-Pass Filters\n",
    "``gumpy`` ships with several filters already implemented. They accept either raw data to be filtered, or a subclass of ``Dataset``. In the latter case, ``gumpy`` will automatically convert all channels using parameters extracted from the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this returns a butter-bandpass filtered version of the entire dataset\n",
    "btr_data = gumpy.signal.butter_bandpass(grazb_data, lo=1, hi=35)\n",
    "\n",
    "# it is also possible to use filters on individual electrodes using \n",
    "# the .raw_data field of a dataset. The example here will remove a certain\n",
    "# from a single electrode using a Notch filter. This example also demonstrates\n",
    "# that parameters will be forwarded to the internal call to the filter, in this\n",
    "# case the scipy implementation iirnotch (Note that iirnotch is only available\n",
    "# in recent versions of scipy, and thus disabled in this example by default)\n",
    "\n",
    "# frequency to be removed from the signal\n",
    "if False:\n",
    "    f0 = 50 \n",
    "    # quality factor\n",
    "    Q = 50 \n",
    "    # get the cutoff frequency\n",
    "    w0 = f0/(grazb_data.sampling_freq/2) \n",
    "    # apply the notch filter\n",
    "    btr_data = gumpy.signal.notch(btr_data[:, 0], w0, Q)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalization\n",
    "Many datasets require normalization. ``gumpy`` provides functions to normalize either using a mean computation or via min/max computation. As with the filters, this function accepts either an instance of ``Dataset``, or raw_data. In fact, it can be used for postprocessing any row-wise data in a numpy matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    # normalize the data first\n",
    "    norm_data = gumpy.signal.normalize(grazb_data, 'mean_std')\n",
    "    # let's see some statistics\n",
    "    print(\"\"\"Normalized Data:\n",
    "      Mean    = {:.3f}\n",
    "      Min     = {:.3f}\n",
    "      Max     = {:.3f}\n",
    "      Std.Dev = {:.3f}\"\"\".format(\n",
    "      np.nanmean(norm_data),np.nanmin(norm_data),np.nanmax(norm_data),np.nanstd(norm_data)\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plotting and Feature Extraction\n",
    "\n",
    "Certainly you wish to plot results. ``gumpy`` provides several functions that show how to implement visualizations. For this purpose it heavily relies on ``matplotlib``, ``pandas``, and ``seaborn``. The following examples will show several of the implemented signal processing methods as well as their corresponding plotting functions. Moreover, the examples will show you how to extract features\n",
    "\n",
    "That said, let's start with a simple visualization where we access the filtered data from above to show you how to access the data and plot it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "Plot after filtering with a butter bandpass (ignore normalization)\n",
    "plt.figure()\n",
    "plt.clf()\n",
    "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 0], label='C3')\n",
    "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 1], alpha=0.7, label='C4')\n",
    "plt.plot(btr_data[grazb_data.trials[0]: grazb_data.trials[1], 2], alpha=0.7, label='Cz')\n",
    "plt.legend()\n",
    "plt.title(\" Filtered Data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EEG band visualization\n",
    "Using ``gumpy``'s filters and the provided method, it is easy to filter and subsequently plot the  EEG bands of a trial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# determine the trial that we wish to plot\n",
    "n_trial = 120\n",
    "# now specify the alpha and beta cutoff frequencies\n",
    "lo_a, lo_b = 7, 16\n",
    "hi_a, hi_b = 13, 24\n",
    "\n",
    "# first step is to filter the data\n",
    "flt_a = gumpy.signal.butter_bandpass(grazb_data, lo=lo_a, hi=hi_a)\n",
    "flt_b = gumpy.signal.butter_bandpass(grazb_data, lo=lo_b, hi=hi_b)\n",
    "\n",
    "# finally we can visualize the data\n",
    "gumpy.plot.EEG_bandwave_visualizer(grazb_data, flt_a, n_trial, lo_a, hi_a)\n",
    "gumpy.plot.EEG_bandwave_visualizer(grazb_data, flt_b, n_trial, lo_a, hi_a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract trials\n",
    "Now we wish to extract the trials from the data. This operation may heavily depend on your dataset, and thus we cannot guarantee that the function works for your specific dataset. However, the used function ``gumpy.utils.extract_trials`` can be used as a guideline how to extract the trials you wish to examine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieve the trials from the filtered data. This requires that the function\n",
    "# knows the number of trials, labels, etc. when only passed a (filtered) data matrix\n",
    "trial_marks = grazb_data.trials\n",
    "labels = grazb_data.labels\n",
    "sampling_freq = grazb_data.sampling_freq\n",
    "epochs = gumpy.utils.extract_trials(grazb_data, trials=trial_marks, labels=labels, sampling_freq=sampling_freq)\n",
    "# it is also possible to pass an instance of Dataset and filtered data.\n",
    "# gumpy will then infer all necessary details from the dataset\n",
    "data_class_b = gumpy.utils.extract_trials(grazb_data, flt_b)\n",
    "\n",
    "# similar to other functions, this one allows to pass an entire instance of Dataset\n",
    "# to operate on the raw data\n",
    "data_class1 = gumpy.utils.extract_trials(grazb_data)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the classes\n",
    "Given the extracted trials from above, we can proceed to visualize the average power of a class. Again, this depends on the specific data and thus you may have to adapt the function accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify some cutoff values for the visualization\n",
    "lowcut_a, highcut_a = 14, 30\n",
    "# and also an interval to display\n",
    "interval_a = [0, 8]\n",
    "# visualize logarithmic power?\n",
    "logarithmic_power = False\n",
    "\n",
    "# visualize the extracted trial from above\n",
    "gumpy.plot.average_power(epochs, lowcut_a, highcut_a, interval_a, grazb_data.sampling_freq, logarithmic_power)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Wavelet transform\n",
    "``gumpy`` relies on ``pywt`` to compute wavelet transforms. Furthermore, it contains convenience functions to visualize the results of the discrete wavelet transform as shown in the example below for the Graz dataset and the classes extracted above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# As with most functions, you can pass arguments to a \n",
    "# gumpy function that will be forwarded to the backend.\n",
    "# In this example the decomposition levels are mandatory, and the \n",
    "# mother wavelet that should be passed is optional\n",
    "level = 6\n",
    "wavelet = 'db4'\n",
    "\n",
    "# now we can retrieve the dwt for the different channels\n",
    "mean_coeff_ch0_c1 = gumpy.signal.dwt(data_class1[0], level=level, wavelet=wavelet)\n",
    "mean_coeff_ch1_c1 = gumpy.signal.dwt(data_class1[1], level=level, wavelet=wavelet)\n",
    "mean_coeff_ch0_c2 = gumpy.signal.dwt(data_class1[3], level=level, wavelet=wavelet)\n",
    "mean_coeff_ch1_c2 = gumpy.signal.dwt(data_class1[4], level=level, wavelet=wavelet)\n",
    "\n",
    "# gumpy's signal.dwt function returns the approximation of the \n",
    "# coefficients as first result, and all the coefficient details as list\n",
    "# as second return value (this is contrast to the backend, which returns\n",
    "# the entire set of coefficients as a single list)\n",
    "approximation_C3 = mean_coeff_ch0_c2[0]\n",
    "approximation_C4 = mean_coeff_ch1_c2[0]\n",
    "\n",
    "# as mentioned in the comment above, the list of details are in the second\n",
    "# return value of gumpy.signal.dwt. Here we save them to additional variables\n",
    "# to improve clarity\n",
    "details_c3_c1 = mean_coeff_ch0_c1[1]\n",
    "details_c4_c1 = mean_coeff_ch1_c1[1]\n",
    "details_c3_c2 = mean_coeff_ch0_c2[1]\n",
    "details_c4_c2 = mean_coeff_ch1_c2[1]\n",
    "\n",
    "# gumpy exhibits a function to plot the dwt results. You must pass three lists,\n",
    "# i.e. the labels of the data, the approximations, as well as the detailed coeffs,\n",
    "# so that gumpy can automatically generate appropriate titles and labels.\n",
    "# you can pass an additional class string that will be incorporated into the title.\n",
    "# the function returns a matplotlib axis object in case you want to further\n",
    "# customize the plot.\n",
    "gumpy.plot.dwt(\n",
    "    [approximation_C3, approximation_C4],\n",
    "   [details_c3_c1, details_c4_c1],\n",
    "   ['C3, c1', 'C4, c1'],    level, grazb_data.sampling_freq, 'Class: Left')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DWT reconstruction and visualization\n",
    "Often a user wantes to reconstruct the power spectrum of a dwt and visualize the results. The functions will return a list of all the reconstructed signals as well as a handle to the figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gumpy.plot.reconstruct_without_approx(\n",
    "    [details_c3_c2[4], details_c4_c2[4]], \n",
    "    ['C3-c1', 'C4-c1'], level=6)\n",
    "\n",
    "gumpy.plot.reconstruct_without_approx(\n",
    "    [details_c3_c1[5], details_c4_c1[5]], \n",
    "    ['C3-c1', 'C4-c1'], level=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gumpy.plot.reconstruct_with_approx(\n",
    "    [details_c3_c1[5], details_c4_c1[5]],\n",
    "    ['C3', 'C4'], wavelet=wavelet)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Welch's Power Spectral Density estimate\n",
    "Estimating the power spectral density according to Welch's method is =imilar to the power reconstruction shown above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the function gumpy.plot.welch_psd returns the power densities as \n",
    "# well as a handle to the figure. You can also pass a figure in if you \n",
    "# wish to modify the plot\n",
    "fig = plt.figure()\n",
    "plt.title('Customized plot')\n",
    "ps, fig = gumpy.plot.welch_psd(\n",
    "    [details_c3_c1[4], details_c3_c2[4]],\n",
    "    ['C3 - c1', 'C4 - c1'],\n",
    "    grazb_data.sampling_freq, fig=fig)\n",
    "\n",
    "ps, fig = gumpy.plot.welch_psd(\n",
    "    [details_c4_c1[4], details_c4_c2[4]],\n",
    "    ['C3 - c1', 'C4 - c1'],\n",
    "    grazb_data.sampling_freq)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  Alpha and Beta sub-bands\n",
    "Using gumpys functions you can quickly define feature extractors. The following examples will demonstrate how you can use the predefined filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def alpha_subBP_features(data):\n",
    "    # filter data in sub-bands by specification of low- and high-cut frequencies\n",
    "    alpha1 = gumpy.signal.butter_bandpass(data, 8.5, 11.5, order=6)\n",
    "    alpha2 = gumpy.signal.butter_bandpass(data, 9.0, 12.5, order=6)\n",
    "    alpha3 = gumpy.signal.butter_bandpass(data, 9.5, 11.5, order=6)\n",
    "    alpha4 = gumpy.signal.butter_bandpass(data, 8.0, 10.5, order=6)\n",
    "\n",
    "    # return a list of sub-bands\n",
    "    return [alpha1, alpha2, alpha3, alpha4]\n",
    "\n",
    "alpha_bands = np.array(alpha_subBP_features(btr_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def beta_subBP_features(data):\n",
    "    beta1 = gumpy.signal.butter_bandpass(data, 14.0, 30.0, order=6)\n",
    "    beta2 = gumpy.signal.butter_bandpass(data, 16.0, 17.0, order=6)\n",
    "    beta3 = gumpy.signal.butter_bandpass(data, 17.0, 18.0, order=6)\n",
    "    beta4 = gumpy.signal.butter_bandpass(data, 18.0, 19.0, order=6)\n",
    "    return [beta1, beta2, beta3, beta4]\n",
    "\n",
    "beta_bands = np.array(beta_subBP_features(btr_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract features without considering class information\n",
    "The following examples show how the sub-bands can be used to extract features. This also shows how the fields of the dataset can be accessed, and how to write methods specific to your data using a mix of gumpy's and numpy's functions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Method 1: logarithmic sub-band power"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def powermean(data, trial, fs, w):\n",
    "    return np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],0],2).mean(), \\\n",
    "           np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],1],2).mean(), \\\n",
    "           np.power(data[trial+fs*4+w[0]: trial+fs*4+w[1],2],2).mean()\n",
    "\n",
    "def log_subBP_feature_extraction(alpha, beta, trials, fs, w):\n",
    "    # number of features combined for all trials\n",
    "    n_features = 15\n",
    "    # initialize the feature matrix\n",
    "    X = np.zeros((len(trials), n_features))\n",
    "    \n",
    "    # Extract features\n",
    "    for t, trial in enumerate(trials):\n",
    "        power_c31, power_c41, power_cz1 = powermean(alpha[0], trial, fs, w)\n",
    "        power_c32, power_c42, power_cz2 = powermean(alpha[1], trial, fs, w)\n",
    "        power_c33, power_c43, power_cz3 = powermean(alpha[2], trial, fs, w)\n",
    "        power_c34, power_c44, power_cz4 = powermean(alpha[3], trial, fs, w)\n",
    "        power_c31_b, power_c41_b, power_cz1_b = powermean(beta[0], trial, fs, w)\n",
    "        \n",
    "        X[t, :] = np.array(\n",
    "            [np.log(power_c31), np.log(power_c41), np.log(power_cz1),\n",
    "             np.log(power_c32), np.log(power_c42), np.log(power_cz2),\n",
    "             np.log(power_c33), np.log(power_c43), np.log(power_cz3), \n",
    "             np.log(power_c34), np.log(power_c44), np.log(power_cz4),\n",
    "             np.log(power_c31_b), np.log(power_c41_b), np.log(power_cz1_b)])\n",
    "\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    w1 = [0,125]\n",
    "    w2 = [125,256]\n",
    "    w3 = [256,512]\n",
    "    w4 = [512,512+256]\n",
    "    \n",
    "    # extract the features\n",
    "    features1 = log_subBP_feature_extraction(\n",
    "        alpha_bands, beta_bands, \n",
    "        grazb_data.trials, grazb_data.sampling_freq,\n",
    "        w1)\n",
    "    features2 = log_subBP_feature_extraction(\n",
    "        alpha_bands, beta_bands, \n",
    "        grazb_data.trials, grazb_data.sampling_freq,\n",
    "        w2)     \n",
    "    features3 = log_subBP_feature_extraction(\n",
    "        alpha_bands, beta_bands, \n",
    "        grazb_data.trials, grazb_data.sampling_freq,\n",
    "        w3)        \n",
    "    features4 = log_subBP_feature_extraction(\n",
    "        alpha_bands, beta_bands, \n",
    "        grazb_data.trials, grazb_data.sampling_freq,\n",
    "        w4)        \n",
    "    print(features4.shape)\n",
    "\n",
    "    # concatenate and normalize the features\n",
    "    features = np.concatenate((features1, features2, features3, features4), axis=1)\n",
    "    features -= np.mean(features)\n",
    "    features = gumpy.signal.normalize(features, 'mean_std')\n",
    "    features = gumpy.signal.normalize(features, 'min_max')\n",
    "\n",
    "    # print shape to quickly check if everything is as expected\n",
    "    print(features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Method 2: Discrete Wavelet Transform (DWT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dwt_features(data, trials, level, sampling_freq, w, n, wavelet): \n",
    "    import pywt\n",
    "    \n",
    "    # number of features per trial\n",
    "    n_features = 9 \n",
    "    # allocate memory to store the features\n",
    "    X = np.zeros((len(trials), n_features))\n",
    "\n",
    "    # Extract Features\n",
    "    for t, trial in enumerate(trials):\n",
    "        signals = data[trial + fs*4 + (w[0]) : trial + fs*4 + (w[1])]\n",
    "        coeffs_c3 = pywt.wavedec(data = signals[:,0], wavelet=wavelet, level=level)\n",
    "        coeffs_c4 = pywt.wavedec(data = signals[:,1], wavelet=wavelet, level=level)\n",
    "        coeffs_cz = pywt.wavedec(data = signals[:,2], wavelet=wavelet, level=level)\n",
    "\n",
    "        X[t, :] = np.array([\n",
    "            np.std(coeffs_c3[n]), np.mean(coeffs_c3[n]**2),  \n",
    "            np.std(coeffs_c4[n]), np.mean(coeffs_c4[n]**2),\n",
    "            np.std(coeffs_cz[n]), np.mean(coeffs_cz[n]**2), \n",
    "            np.mean(coeffs_c3[n]),\n",
    "            np.mean(coeffs_c4[n]), \n",
    "            np.mean(coeffs_cz[n])])\n",
    "        \n",
    "    return X\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We'll work with the data that was postprocessed using a butter bandpass\n",
    "# filter further above\n",
    "\n",
    "# to see it work, enable here. We'll use the log-power features further\n",
    "# below, though\n",
    "if False:\n",
    "    w = [0, 256]\n",
    "    \n",
    "    # extract the features\n",
    "    trials = grazb_data.trials\n",
    "    fs = grazb_data.sampling_freq\n",
    "    features1= np.array(dwt_features(btr_data, trials, 5, fs, w, 3, \"db4\"))\n",
    "    features2= np.array(dwt_features(btr_data, trials, 5, fs, w, 4, \"db4\"))\n",
    "\n",
    "    # concatenate and normalize the features \n",
    "    features = np.concatenate((features1, features2), axis=1)\n",
    "    features -= np.mean(features)\n",
    "    features = gumpy.signal.normalize(features, 'min_max')\n",
    "    \n",
    "    print(features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Split the data\n",
    "Now that we extracted features (and reduced the dimensionality), we can split the data for \n",
    "test and training purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# gumpy exposes several methods to split a dataset, as shown in the examples:\n",
    "if 1: \n",
    "    split_features = np.array(gumpy.split.normal(features, labels,test_size=0.2))\n",
    "    X_train = split_features[0]\n",
    "    X_test = split_features[1]\n",
    "    Y_train = split_features[2]\n",
    "    Y_test = split_features[3]\n",
    "    X_train.shape\n",
    "if 0: \n",
    "    n_splits=5\n",
    "    split_features = np.array(gumpy.split.time_series_split(features, labels, n_splits)) \n",
    "if 0: \n",
    "    split_features = np.array(gumpy.split.normal(PCA, labels, test_size=0.2))\n",
    "    \n",
    "#ShuffleSplit: Random permutation cross-validator \n",
    "if 0: \n",
    "    split_features = gumpy.split.shuffle_Split(features, labels, n_splits=10,test_size=0.2,random_state=0)\n",
    "    \n",
    "# #Stratified K-Folds cross-validator\n",
    "# #Stratification is the process of rearranging the data as to ensure each fold is a good representative of the whole   \n",
    "if 0: \n",
    "    split_features = gumpy.split.stratified_KFold(features, labels, n_splits=3)\n",
    "    \n",
    "#Stratified ShuffleSplit cross-validator   \n",
    "#Repeated random sub-sampling validation\n",
    "if 0: \n",
    "    split_features = gumpy.split.stratified_shuffle_Split(features, labels, n_splits=10,test_size=0.3,random_state=0)\n",
    "\n",
    "\n",
    "# # the functions return a list with the data according to the following example\n",
    "# X_train = split_features[0]\n",
    "# X_test = split_features[1]\n",
    "# Y_train = split_features[2]\n",
    "# Y_test = split_features[3]\n",
    "# X_train.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract features considering class information"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Method 3: Common Spatial Patterns (CSP)\n",
    "#### I - Common prerequesites"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_features(epoch, spatial_filter):\n",
    "    feature_matrix = np.dot(spatial_filter, epoch)\n",
    "    variance = np.var(feature_matrix, axis=1)\n",
    "    if np.all(variance == 0):\n",
    "        return np.zeros(spatial_filter.shape[0])\n",
    "    features = np.log(variance/np.sum(variance))\n",
    "    return features\n",
    "\n",
    "def select_filters(filters, num_components = None, verbose = False):\n",
    "    if verbose:\n",
    "        print(\"Incomming filters:\", filters.shape, \"\\nNumber of components:\", num_components if num_components else \" all\")\n",
    "    if num_components == None:\n",
    "        return filters\n",
    "    assert num_components <= filters.shape[0]/2, \"The requested number of components is too high\"\n",
    "    selection = list(range(0, num_components)) + list(range(filters.shape[0] - num_components, filters.shape[0]))\n",
    "    reduced_filters = filters[selection,:]\n",
    "    return reduced_filters\n",
    "\n",
    "# Select the number of used spatial components\n",
    "n_components = None; # assign None for all components\n",
    "\n",
    "# Rearrange epochs to (trials x channels x timesteps)\n",
    "epochs1 = np.swapaxes(np.array([epochs[0],epochs[1],epochs[2]]),1,0)\n",
    "epochs2 = np.swapaxes(np.array([epochs[3],epochs[4],epochs[5]]),1,0)\n",
    "print(\"Number of trials per class:\",epochs1.shape[0],\"|\",epochs2.shape[0])\n",
    "#print(epochs1.shape)\n",
    "#print(epochs2.shape)\n",
    "\n",
    "# Remove invalid epochs\n",
    "invalid_entries_1 = np.where([np.all(epoch == 0) for epoch in epochs1])\n",
    "invalid_entries_2 = np.where([np.all(epoch == 0) for epoch in epochs2])\n",
    "print(\"Number of invalid trials per class:\",len(invalid_entries_1[0]),\"|\",len(invalid_entries_2[0]))\n",
    "#print(invalid_entries_1)\n",
    "#print(invalid_entries_2)\n",
    "epochs1 = np.delete(epochs1, invalid_entries_1, 0)\n",
    "epochs2 = np.delete(epochs2, invalid_entries_2, 0)\n",
    "print(\"Number of trials per class after cleanup:\",epochs1.shape[0],\"|\",epochs2.shape[0])\n",
    "\n",
    "# Concatenate the trials\n",
    "epochs_re = np.concatenate((epochs1, epochs2), axis=0)\n",
    "print(\"Dataset:\", epochs_re.shape)\n",
    "\n",
    "# Update the label vector\n",
    "labels = np.ones(epochs_re.shape[0])\n",
    "labels[:epochs1.shape[0]] = 0\n",
    "\n",
    "# Split data\n",
    "epochs_train, epochs_test, y_train, y_test = gumpy.split.stratified_shuffle_Split(epochs_re, labels, n_splits=10,test_size=0.2,random_state=0)\n",
    "print(\"Training data:\", epochs_train.shape, \"Testing data:\", epochs_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### II A - Standard implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:    \n",
    "    # Generate the spacial filterns for the training data\n",
    "    temp_filters = np.asarray(gumpy.features.CSP(epochs_train[np.where(y_train==0)], epochs_train[np.where(y_train==1)]))\n",
    "    #print(temp_filters.shape)\n",
    "    spatial_filter = select_filters(temp_filters[0], n_components)\n",
    "    #print(spatial_filter.shape)\n",
    "    \n",
    "    # Extract the CSPs\n",
    "    features_train = np.array([extract_features(epoch, spatial_filter) for epoch in epochs_train])\n",
    "    #print(features_train.shape) \n",
    "    \n",
    "    features_test = np.array([extract_features(epoch, spatial_filter) for epoch in epochs_test])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### II B - Filter Bank CSP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    # Apply the filter bank\n",
    "    freqs = [[1,4],[4,8],[8,13],[13,22],[22,30]] # delta, theta, alpha, low beta, high beta    \n",
    "    #freqs = [[1,4],[4,8],[8,13],[13,22],[22,30],[1,30]]\n",
    "    #freqs = [[1,4],[4,8],[8,12],[12,16],[16,20],[20,24],[24,28],[28,32]]\n",
    "    #freqs = [[1,3],[3,5],[5,7],[7,9],[9,11],[11,13],[13,15],[15,17],[17,19],[19,21],[21,23],[23,25],[25,27],[27,29],[29,31],[31,33],[33,35]]\n",
    "    x_train_fb = [gumpy.signal.butter_bandpass(epochs_train, lo=f[0], hi=f[1]) for f in freqs]\n",
    "    x_test_fb = [gumpy.signal.butter_bandpass(epochs_test, lo=f[0], hi=f[1]) for f in freqs]\n",
    "\n",
    "    # Generate the spatial filters for the training data\n",
    "    temp_filters = [np.asarray(gumpy.features.CSP(x_train_fb[f][np.where(y_train==0)], x_train_fb[f][np.where(y_train==1)])) for f in range(len(freqs))]\n",
    "    if n_components is not None:\n",
    "        spatial_filters = [select_filters(temp_filters[f][0], n_components) for f in range(len(freqs))]\n",
    "    else:\n",
    "        spatial_filters = [temp_filters[f][0] for f in range(len(freqs))]\n",
    "\n",
    "    # Extract the CSPs\n",
    "    features_train_tmp = [np.array([extract_features(epoch, spatial_filters[f]) for epoch in x_train_fb[f]]) for f in range(len(freqs))]\n",
    "    features_train = np.concatenate(features_train_tmp, axis=1)\n",
    "    features_test_tmp = [np.array([extract_features(epoch, spatial_filters[f]) for epoch in x_test_fb[f]]) for f in range(len(freqs))]\n",
    "    features_test = np.concatenate(features_test_tmp, axis=1)\n",
    "\n",
    "    print(features_train.shape)\n",
    "    print(features_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### III - Common postprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    # Feature normalization\n",
    "    features_train = gumpy.signal.normalize(features_train, 'mean_std')\n",
    "    features_test = gumpy.signal.normalize(features_test, 'mean_std')\n",
    "    X_train = gumpy.signal.normalize(features_train, 'min_max')\n",
    "    X_test = gumpy.signal.normalize(features_test, 'min_max')\n",
    "    Y_train = y_train\n",
    "    Y_test = y_test\n",
    "    \n",
    "# Debugging output\n",
    "#usable_epochs = np.concatenate((epochs1, epochs2), axis=0)\n",
    "#print(usable_epochs.shape)\n",
    "#print(features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select features "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequential Feature Selection Algorithm\n",
    "``gumpy`` provides a generic function, ``gumpy.features.sequential_feature_selector``, with which you can select features. For a list of the implemented selectors, please have a look at the function documentation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    Results = []\n",
    "    classifiers = []\n",
    "    Accuracy = []\n",
    "    Final_results = {}\n",
    "    for model in gumpy.classification.available_classifiers:\n",
    "        print(model)\n",
    "        feature_idx, cv_scores, algorithm, sfs, clf = gumpy.features.sequential_feature_selector(X_train, Y_train, model, (6, 10), 5, 'SFFS')\n",
    "        classifiers.append(model)\n",
    "        Accuracy.append(cv_scores * 100)\n",
    "        Final_results[model] = cv_scores * 100\n",
    "        print(Final_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PCA \n",
    "``gumpy`` provides a wrapper around sklearn to reduce the dimensionality via PCA in a straightforward manner."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The second argument (0.95) is assumed to be the fraction of variance to preserve\n",
    "PCA = gumpy.features.PCA_dim_red(features, 0.95)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting features\n",
    "``gumpy`` wraps 3D plotting of features into a single line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gumpy.plot.PCA(\"3D\", features, split_features[0], split_features[2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validation and test accuracies "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#SVM, RF, KNN, NB, LR, QLDA, LDA\n",
    "# sklearn.cross_validation was removed in scikit-learn 0.20; use sklearn.model_selection\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "# Select features with SFFS for the chosen classifier\n",
    "feature_idx, cv_scores, algorithm, sfs, clf = gumpy.features.sequential_feature_selector(X_train, Y_train, 'RandomForest', (6, 10), 10, 'SFFS')\n",
    "\n",
    "# Cross-validate the classifier on the selected training features\n",
    "feature = X_train[:, feature_idx]\n",
    "scores = cross_val_score(clf, feature, Y_train, cv=10)\n",
    "print(\"Validation Accuracy: %0.2f (+/- %0.2f)\" % (scores.mean(), scores.std()))\n",
    "\n",
    "# Refit on the full training set and evaluate on the held-out test set\n",
    "clf.fit(feature, Y_train)\n",
    "feature1 = X_test[:, feature_idx]\n",
    "y_pred = clf.predict(feature1)\n",
    "f = clf.score(feature1, Y_test)\n",
    "print(\"Test Accuracy:\", f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Voting classifiers "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `gumpy.classification.vote` uses `sklearn.ensemble.VotingClassifier` as its backend, you can choose the voting method, either 'soft' or 'hard'. In addition, the function can be told to first select features via `mlxtend.feature_selection.SequentialFeatureSelector` before classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'soft', False, (6,12))\n",
    "    print(\"Classification result for soft voting classifier\")\n",
    "    print(result)\n",
    "    print(\"Accuracy: \", result.accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if True:\n",
    "    result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'hard', False, (6,12))\n",
    "    print(\"Classification result for hard voting classifier\")\n",
    "    print(result)\n",
    "    print(\"Accuracy: \", result.accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Voting Classifier with feature selection\n",
    "`gumpy` can automatically combine all classifiers listed in `gumpy.classification.available_classifiers` into a voting classifier (for more details see `sklearn.ensemble.VotingClassifier`). If you developed a custom classifier and registered it using the `@register_classifier` decorator, it will be included automatically as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result, _ = gumpy.classification.vote(X_train, Y_train, X_test, Y_test, 'soft', True, (6,12))\n",
    "print(\"Classification result for soft voting classifier\")\n",
    "print(result)\n",
    "print(\"Accuracy: \", result.accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classification without feature selection "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if False:\n",
    "    Results = []\n",
    "    classifiers = []\n",
    "    Accuracy = []\n",
    "    Final_results = {}\n",
    "    for model in gumpy.classification.available_classifiers:\n",
    "        results, clf = gumpy.classify(model, X_train, Y_train, X_test, Y_test)\n",
    "        print(model)\n",
    "        print(results)\n",
    "        classifiers.append(model)\n",
    "        Accuracy.append(results.accuracy)\n",
    "        Final_results[model] = results.accuracy\n",
    "    print(Final_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Confusion Matrix\n",
    "One of the ideas behind ``gumpy`` is to give users the means to quickly examine their data. To that end, ``gumpy`` mostly wraps existing libraries. This makes it easy to display data while still allowing you to modify the plots in any way the underlying libraries support:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Method 1\n",
    "# `results` is returned by gumpy.classify in the cell above, so enable that cell first\n",
    "gumpy.plot.confusion_matrix(Y_test, results.pred)\n",
    "\n",
    "# Method 2 (template: fill in the placeholders before uncommenting)\n",
    "# gumpy.plot.plot_confusion_matrix(path='...', cm=..., normalize=False, target_names=['...'], title=\"...\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
